from numpy import block, isin
import torch
import torch.nn as nn
from models.block.DepthSeperable import DepthSeperabelConv2d
from models.utils.utils import get_activation_function, make_conv_block


class MobileNetv1(nn.Module):
    """MobileNet v1: a conv stem followed by a stack of depthwise-separable conv blocks.

    Args:
        activation_type: passed to ``get_activation_function``. The returned object is
            shared by the stem and every block. ``next()`` is also called on it by
            ``get_minmax`` / ``get_activation`` to learn which module type counts as
            "the activation".
        num_classes: output size of the final linear layer.
        oper_order: operation-order string forwarded to the conv blocks
            (e.g. ``'cba'`` or ``'cab'``).
        dataset: selects the stem stride and the final pooling size. Names containing
            ``'cifar'`` or ``'tinyImageNet'`` use stride 1 in the stem. ``'ImageNet'``
            and ``'cub200'`` use stride 2.
        alpha: width multiplier applied to every conv layer (channel counts are
            truncated with ``int``).
        cut_block: if given, number of trailing ``cfg`` entries to drop.
        additional_cut: indices deleted from ``cfg`` one at a time after the trailing
            cut. Only used when ``cut_block`` is not None.
        depthwise_acti: forwarded to every ``DepthSeperabelConv2d``.

    Raises:
        ValueError: if ``cut_block`` is negative or ``dataset`` is not supported.
    """
    # (128, 2) means conv planes=128, conv stride=2; a bare int means stride=1.
    cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]

    def __init__(self, activation_type, num_classes=10, oper_order='cba', dataset='cifar10', alpha=1.0,
                 cut_block=None, additional_cut=(2, 4), depthwise_acti=True):
        super(MobileNetv1, self).__init__()
        self.activation_generator = get_activation_function(activation_type)
        self.oper_order = oper_order
        self.alpha = alpha
        # Factor by which the final feature map grows because stride-2 blocks were removed.
        self.cutted_resolution = 1

        is_tiny = 'tinyImageNet' in dataset
        stride = 1 if ('cifar' in dataset or is_tiny) else 2
        self.stem = make_conv_block(3, int(32 * self.alpha), kernel_size=3, stride=stride, padding=1,
                                    activation_generator=self.activation_generator, oper_order=self.oper_order)

        if cut_block is not None:
            n_cut = int(cut_block)
            if n_cut < 0:
                raise ValueError("cut_block must be non-negative, got {!r}".format(cut_block))
            # Split by an explicit index so cut_block=0 keeps every block
            # (cfg[-0:] / cfg[:-0] would select everything / nothing).
            # Slicing copies, so the deletions below never touch the class-level cfg.
            split = max(len(self.cfg) - n_cut, 0)
            cutted = self.cfg[split:]
            self.cfg = self.cfg[:split]

            if additional_cut is not None:
                for ind in additional_cut:
                    # NOTE(review): indices are applied one at a time, so each deletion
                    # shifts later positions (with the default (2, 4) the second delete
                    # removes what was originally index 5). Kept as-is so existing
                    # architectures/checkpoints are unchanged. Confirm this is intended.
                    if isinstance(self.cfg[ind], tuple):
                        self.cutted_resolution *= 2
                    del self.cfg[ind]

            # Removing a stride-2 block doubles the last average-pooling kernel size.
            self.cutted_resolution *= 2 ** sum(isinstance(b, tuple) for b in cutted)

        self.layers = self._make_layers(in_planes=32, depthwise_acti=depthwise_acti)

        last_planes = self.cfg[-1] if not isinstance(self.cfg[-1], tuple) else self.cfg[-1][0]
        # Width actually produced by the last block (same int() rounding as _make_layers).
        last_channel = int(last_planes * self.alpha)

        if 'cifar' in dataset:
            self.avgpool = nn.AvgPool2d(2 * self.cutted_resolution)
        elif is_tiny:
            self.avgpool = nn.AvgPool2d(2 * self.cutted_resolution)
            # The head flattens 4x the channel count here (presumably a 2x2 pooled map).
            # Scale after rounding so it matches the real flattened size for any alpha.
            last_channel *= (2 * 2)
        elif dataset == 'ImageNet' or dataset == 'cub200':
            self.avgpool = nn.AvgPool2d(7 * self.cutted_resolution)
        else:
            raise ValueError("unsupported dataset: {!r}".format(dataset))

        self.linear = nn.Linear(last_channel, num_classes)

    def _make_layers(self, in_planes, depthwise_acti):
        """Build one DepthSeperabelConv2d per ``cfg`` entry, chaining channel counts.

        A bare int entry means ``out_planes`` with stride 1. A tuple entry is
        ``(out_planes, stride)``. Channel counts are scaled by ``alpha``.
        """
        layers = []
        for x in self.cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(DepthSeperabelConv2d(int(in_planes * self.alpha), int(out_planes * self.alpha),
                                               kernel_size=3, stride=stride, padding=1,
                                               activation_generator=self.activation_generator,
                                               oper_order=self.oper_order, depthwise_acti=depthwise_acti))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def _head(self, x):
        """Pool (when ``avgpool`` is set), flatten per sample and apply the linear classifier."""
        if self.avgpool is not None:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.linear(x)

    def forward(self, x):
        """Return class logits for input batch ``x``."""
        out = self.layers(self.stem(x))
        return self._head(out)

    @staticmethod
    def _channel_minmax(x):
        """Return a ``(C, 2)`` tensor of ``[min, max]`` for each channel (dim 1) of ``x``."""
        flatten_channel = x.transpose(0, 1).reshape(x.size(1), -1)
        c_min, _ = torch.min(flatten_channel, dim=1)
        c_max, _ = torch.max(flatten_channel, dim=1)
        return torch.stack([c_min, c_max], dim=1)

    @staticmethod
    def _layer_minmax(x):
        """Return a ``(2,)`` tensor ``[min, max]`` over every element of ``x``."""
        return torch.stack([x.min(), x.max()], dim=0)

    def _iter_primitive_modules(self):
        """Yield the stem's modules, then each block's depthwise and pointwise modules, in order."""
        yield from self.stem
        for layer in self.layers:
            yield from layer.depthwise
            yield from layer.pointwise

    def _collect_features(self, x, block_output):
        """Run the stem and blocks on ``x`` and return ``(features, out)``.

        ``out`` is the last block's output, i.e. the head's input. If ``block_output``
        is True, ``features`` holds the stem output followed by every block output.
        Otherwise it holds the output of every module in the stem and in each block's
        ``depthwise``/``pointwise`` containers whose type matches the object returned
        by ``next(self.activation_generator)``.
        """
        features = []
        if block_output:
            x = self.stem(x)
            features.append(x)
            for layer in self.layers:
                x = layer(x)
                features.append(x)
            return features, x

        acti_cls = type(next(self.activation_generator))
        # NOTE(review): this replays each block module-by-module (depthwise then
        # pointwise) instead of calling DepthSeperabelConv2d.forward, so it assumes
        # that forward is exactly pointwise(depthwise(x)). Confirm.
        for module in self._iter_primitive_modules():
            x = module(x)
            if isinstance(module, acti_cls):
                features.append(x)
        return features, x

    def get_minmax(self, x, block_output=False, channel_flag=False):
        """Run ``x`` through the network and return value ranges of intermediate tensors.

        Args:
            x: input batch.
            block_output: if True, summarise the stem output, every block output and
                the logits (default cfg: 1 + 13 + 1 = 15 entries). If False, summarise
                every activation-module output in the stem and blocks (default cfg with
                one activation per conv: 1 + 13 * 2 = 27 entries). In that case the
                head is not run.
            channel_flag: if True each entry is a ``(C, 2)`` tensor of per-channel
                ``[min, max]`` (channel = dim 1). Otherwise it is a ``(2,)`` tensor
                ``[min, max]`` over the whole tensor.

        Returns:
            list of tensors in forward order.
        """
        summarize = self._channel_minmax if channel_flag else self._layer_minmax
        features, out = self._collect_features(x, block_output)
        minmax = [summarize(feature) for feature in features]
        if block_output:
            minmax.append(summarize(self._head(out)))
        return minmax

    def get_activation(self, x, block_output=False):
        """Run ``x`` through the network and return ``(features, logits)``.

        With ``block_output`` True, ``features`` are the stem output and every block
        output (default cfg: 1 + 13 = 14). Otherwise they are the activation-module
        outputs inside the stem and blocks (default cfg: 1 + 13 * 2 = 27). ``logits``
        is always the full network output.
        """
        features, out = self._collect_features(x, block_output)
        return features, self._head(out)

#############################################################
import numpy as np
def check(data):
    """Print how many tensors ``data`` holds, then each tensor's shape and value range.

    The min/max line is meant to show whether a captured tensor is an
    activation output or a BN output.
    """
    print(len(data))
    for tensor in data:
        print(tensor.shape)
        lo, hi = tensor.min(), tensor.max()
        print(lo, hi)


def test():
    """Smoke-test MobileNetv1 on a random CIFAR-sized batch and print diagnostics.

    Builds a default model, compares the logits from ``forward`` with those
    returned by both ``get_activation`` modes, prints the captured tensors and
    their min/max summaries via ``check``, then exits the interpreter.
    """
    import sys

    net = MobileNetv1(activation_type='tanh', oper_order='cab')  # oper_order='cba' or 'cab'
    x = torch.randn(2, 3, 32, 32)
    y = net(x)
    print("logit size = ", y.size())

    block_output, logit1 = net.get_activation(x, block_output=True)
    activation_output, logit2 = net.get_activation(x, block_output=False)

    print()
    print("--block output--")
    check(block_output)
    print()
    print("--activation output--")
    check(activation_output)

    block_minmax = net.get_minmax(x, block_output=True, channel_flag=True)
    activation_minmax = net.get_minmax(x, block_output=False, channel_flag=True)

    print()
    print("--block minmax--")
    check(block_minmax)
    print()
    print("--activation minmax--")
    check(activation_minmax)

    print()
    print("--logit and real logit--")
    print('logit1 =', logit1)
    print('logit2 =', logit2)
    print('real logit =', y)
    # These logits come from different code paths, so compare them with a
    # tolerance instead of exact float equality.
    print(torch.allclose(logit1, y) and torch.allclose(logit2, y))

    # sys.exit() works even when the interactive `exit` helper (from `site`) is absent.
    sys.exit()
        

"""
if __name__ == '__main__':
model = MobileNetv1(activation_type="tanh", num_classes=100,
                            oper_order="cab", dataset="cifar100", alpha=1,
                            cut_block=None)
"""
# test()